# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" MobileNetV3 head."""

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor


class LinearClsMobilenetv3Head(nn.Cell):
    """
    Linear classification head for MobileNetV3: Dense -> HSwish -> Dense.

    Args:
        model_name (str): Backbone variant, 'large' or 'small'; selects the
            head's input and hidden channel widths.
        num_classes (int): Number of output classes (width of the second
            Dense layer). Default: 1000.
        has_dropout (bool): Whether to insert a Dropout cell before the first
            Dense layer. Default: False.
        activation2 (str): Optional final activation applied in construct():
            'Sigmoid', 'Softmax', or 'None' for no activation. Default: 'None'.

    Raises:
        ValueError: If `model_name` is not 'large' or 'small'.

    Returns:
        Tensor, output tensor of shape (batch, num_classes).
    """

    def __init__(self, model_name, num_classes=1000, has_dropout=False, activation2="None"):
        super(LinearClsMobilenetv3Head, self).__init__()
        # Channel widths per backbone variant: cls_ch_squeeze is the feature
        # width produced by the backbone, cls_ch_expand the head's hidden width.
        model_cfgs = {
            "large": {
                "cls_ch_squeeze": 960,
                "cls_ch_expand": 1280,
            },
            "small": {
                "cls_ch_squeeze": 576,
                "cls_ch_expand": 1024,
            }
        }
        # Fail fast on an unsupported variant instead of leaving the channel
        # attributes unset and crashing later with an opaque AttributeError.
        if model_name not in model_cfgs:
            raise ValueError("model_name must be 'large' or 'small', got: {}".format(model_name))
        cfg = model_cfgs[model_name]
        self.input_channel = cfg["cls_ch_squeeze"]
        self.output_channel1 = cfg["cls_ch_expand"]

        head1 = [nn.Dense(self.input_channel, self.output_channel1, has_bias=True)]
        if has_dropout:
            # NOTE(review): under the MindSpore 1.x API the positional argument
            # of nn.Dropout is keep_prob, so 0.2 keeps only 20% of activations.
            # Confirm this is the intended rate for the installed version.
            head1 = [nn.Dropout(0.2)] + head1
        self.head1 = nn.SequentialCell(head1)
        self.activation1 = nn.HSwish()
        self.head2 = nn.SequentialCell([nn.Dense(self.output_channel1, num_classes, has_bias=True)])
        self.head = nn.SequentialCell([self.head1, self.activation1, self.head2])

        # Optional final activation. Bug fix: the Softmax branch previously
        # assigned self.activation2 while construct() reads self.activation,
        # so activation2 == "Softmax" raised AttributeError at runtime.
        self.need_activation = True
        if activation2 == "Sigmoid":
            self.activation = P.Sigmoid()
        elif activation2 == "Softmax":
            self.activation = P.Softmax()
        else:
            self.need_activation = False
        self._initialize_weights()

    def construct(self, x):
        """Run the head on backbone features x; apply the optional final activation."""
        x = self.head(x)
        if self.need_activation:
            x = self.activation(x)
        return x

    def _initialize_weights(self):
        """
        Initialize weights.

        Dense weights are drawn from N(0, 0.01); Dense biases are zeroed.

        Returns:
            None.
        """
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(np.random.normal(
                    0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))

    @property
    def get_head(self):
        """Return the assembled head as a SequentialCell."""
        return self.head
